Conversation
|
I think this is elegant and useful. I'm working on some improvements to #203. Muon is optimal for linear layers, but doesn't make as much sense for e.g. |
|
One could do something like: function fun_rule(model, rule=Muon(), fallback=Adam())
skipped = Base.IdSet{Any}([model.encode.weight, model.decode.weight])
fun(x::AbstractVector) = fallback
fun(x::AbstractArray) = x in skipped ? fallback : rule
return fun
end
opt_state = Optimisers.setup(fun_rule(model), model)such that: julia> model = (;
encode=(; weight=rand(2,2)),
other=(; weight=rand(2,2), bias=rand(2)),
decode=(; weight=rand(2,2)));
julia> fun_rule(model)(model.encode.weight)
Adam(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-8)
julia> fun_rule(model)(model.other.weight)
Muon(0.02, 0.95, 0.01, 1.0e-7, true)
julia> fun_rule(model)(model.other.bias)
Adam(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-8)
julia> fun_rule(model)(model.decode.weight)
Adam(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-8)I generally avoid closures, but this has a certain elegance to it. skipped = keys(IdDict([model.encode.weight, model.decode.weight] .=> nothing)) |
Quick sketch of one way to easily allow different rules for different arrays, by modifying
setup-- see docstring.PR Checklist